# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.components.embedders.types.protocol import TextEmbedder
from haystack.components.retrievers.types import EmbeddingRetriever
from haystack.core.serialization import component_to_dict
from haystack.utils.deserialization import deserialize_component_inplace


@component
class MultiQueryEmbeddingRetriever:
    """
    A component that retrieves documents using multiple queries in parallel with an embedding-based retriever.

    This component takes a list of text queries, converts them to embeddings using a query embedder,
    and then uses an embedding-based retriever to find relevant documents for each query in parallel.
    The results are combined, deduplicated, and sorted by relevance score. When the same document is
    returned for several queries, the copy with the highest score is kept.

    ### Usage example

    ```python
    from haystack import Document
    from haystack.document_stores.in_memory import InMemoryDocumentStore
    from haystack.document_stores.types import DuplicatePolicy
    from haystack.components.embedders import SentenceTransformersTextEmbedder
    from haystack.components.embedders import SentenceTransformersDocumentEmbedder
    from haystack.components.retrievers import InMemoryEmbeddingRetriever
    from haystack.components.writers import DocumentWriter
    from haystack.components.retrievers import MultiQueryEmbeddingRetriever

    documents = [
        Document(content="Renewable energy is energy that is collected from renewable resources."),
        Document(content="Solar energy is a type of green energy that is harnessed from the sun."),
        Document(content="Wind energy is another type of green energy that is generated by wind turbines."),
        Document(content="Geothermal energy is heat that comes from the sub-surface of the earth."),
        Document(content="Biomass energy is produced from organic materials, such as plant and animal waste."),
        Document(content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources."),
    ]

    # Populate the document store
    doc_store = InMemoryDocumentStore()
    doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
    doc_embedder.warm_up()
    doc_writer = DocumentWriter(document_store=doc_store, policy=DuplicatePolicy.SKIP)
    documents = doc_embedder.run(documents)["documents"]
    doc_writer.run(documents=documents)

    # Run the multi-query retriever
    in_memory_retriever = InMemoryEmbeddingRetriever(document_store=doc_store, top_k=1)
    query_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")

    multi_query_retriever = MultiQueryEmbeddingRetriever(
        retriever=in_memory_retriever,
        query_embedder=query_embedder,
        max_workers=3
    )

    queries = ["Geothermal energy", "natural gas", "turbines"]
    result = multi_query_retriever.run(queries=queries)
    for doc in result["documents"]:
        print(f"Content: {doc.content}, Score: {doc.score}")
    # The wrapped retriever is configured with top_k=1, so each query contributes at most one
    # document: at most three lines are printed here, ordered from highest to lowest score.
    ```
    """  # noqa E501

    def __init__(self, *, retriever: EmbeddingRetriever, query_embedder: TextEmbedder, max_workers: int = 3) -> None:
        """
        Initialize MultiQueryEmbeddingRetriever.

        :param retriever: The embedding-based retriever to use for document retrieval.
        :param query_embedder: The query embedder to convert text queries to embeddings.
        :param max_workers: Maximum number of worker threads for parallel processing.
        """
        self.retriever = retriever
        self.query_embedder = query_embedder
        self.max_workers = max_workers
        self._is_warmed_up = False

    def warm_up(self) -> None:
        """
        Warm up the query embedder and the retriever if any has a warm_up method.

        Calling this method more than once is a no-op after the first successful call.
        """
        if self._is_warmed_up:
            return

        # The embedder is warmed up first, then the retriever; either may lack a warm_up method.
        for sub_component in (self.query_embedder, self.retriever):
            warm_up = getattr(sub_component, "warm_up", None)
            if callable(warm_up):
                warm_up()
        self._is_warmed_up = True

    @component.output_types(documents=list[Document])
    def run(self, queries: list[str], retriever_kwargs: Optional[dict[str, Any]] = None) -> dict[str, list[Document]]:
        """
        Retrieve documents using multiple queries in parallel.

        Documents returned for more than one query are deduplicated by content (or by id when a
        document has no text content); the highest-scoring copy is kept.

        :param queries: List of text queries to process.
        :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method.
        :returns:
            A dictionary containing:
                - `documents`: List of retrieved documents sorted by relevance score.
        """
        retriever_kwargs = retriever_kwargs or {}

        self.warm_up()

        # Nothing to embed or retrieve: skip creating the thread pool altogether.
        if not queries:
            return {"documents": []}

        # Maps a deduplication key to the best-scoring document seen so far. Re-assigning an
        # existing key keeps its original position, so ties remain in first-seen order.
        best_by_key: dict[tuple[str, Any], Document] = {}

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # executor.map yields results in the order of `queries`, which keeps the output deterministic.
            per_query_results = executor.map(lambda query: self._run_on_thread(query, retriever_kwargs), queries)
            for result in per_query_results:
                for doc in result or []:
                    key = self._dedup_key(doc)
                    current = best_by_key.get(key)
                    # Strict comparison: on equal scores the first-seen copy wins.
                    if current is None or self._score(doc) > self._score(current):
                        best_by_key[key] = doc

        # sorted() is stable, so documents with equal scores keep their first-seen order.
        docs = sorted(best_by_key.values(), key=self._score, reverse=True)
        return {"documents": docs}

    @staticmethod
    def _score(doc: Document) -> float:
        """
        Return the score used for ranking; a missing score is treated as 0.0.
        """
        return doc.score or 0.0

    @staticmethod
    def _dedup_key(doc: Document) -> tuple[str, Any]:
        """
        Return the key identifying duplicates of a document.

        Documents are deduplicated on their text content. Documents without text content fall back
        to their id, so that distinct non-text documents are not collapsed into a single result.
        The key is namespaced so a content string can never collide with an id.
        """
        if doc.content is not None:
            return ("content", doc.content)
        return ("id", doc.id)

    def _run_on_thread(self, query: str, retriever_kwargs: Optional[dict[str, Any]] = None) -> Optional[list[Document]]:
        """
        Process a single query on a separate thread.

        :param query: The text query to process.
        :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method.
        :returns:
            List of retrieved documents or None if no results.
        """
        embedding_result = self.query_embedder.run(text=query)
        query_embedding = embedding_result["embedding"]
        result = self.retriever.run(query_embedding=query_embedding, **(retriever_kwargs or {}))
        if result and "documents" in result:
            return result["documents"]
        return None

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            A dictionary representing the serialized component.
        """
        return default_to_dict(
            self,
            retriever=component_to_dict(obj=self.retriever, name="retriever"),
            query_embedder=component_to_dict(obj=self.query_embedder, name="query_embedder"),
            max_workers=self.max_workers,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MultiQueryEmbeddingRetriever":
        """
        Deserializes the component from a dictionary.

        The nested `retriever` and `query_embedder` entries of `data["init_parameters"]` are
        deserialized in place before the component itself is constructed.

        :param data: The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        deserialize_component_inplace(data["init_parameters"], key="retriever")
        deserialize_component_inplace(data["init_parameters"], key="query_embedder")
        return default_from_dict(cls, data)
